{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overfitting and regularization (with ``gluon``)\n",
    "\n",
    "Now that we've built a [regularized logistic regression model from scratch](regularization-scratch.html), let's make this more efficient with ``gluon``. We recommend reading that chapter first for a description of why regularization is a good idea. As always, we begin by loading libraries and some data.\n",
    "\n",
    "[**REFINED DRAFT - RELEASE STAGE: CATFOOD**]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "from mxnet import autograd\n",
    "from mxnet import gluon\n",
    "import mxnet.ndarray as nd\n",
    "import numpy as np\n",
    "ctx = mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mnist = mx.test_utils.get_mnist()\n",
    "num_examples = 1000\n",
    "batch_size = 64\n",
    "train_data = mx.gluon.data.DataLoader(\n",
    "    mx.gluon.data.ArrayDataset(mnist[\"train_data\"][:num_examples],\n",
    "                               mnist[\"train_label\"][:num_examples].astype(np.float32)), \n",
    "                               batch_size, shuffle=True)\n",
    "test_data = mx.gluon.data.DataLoader(\n",
    "    mx.gluon.data.ArrayDataset(mnist[\"test_data\"][:num_examples],\n",
    "                               mnist[\"test_label\"][:num_examples].astype(np.float32)), \n",
    "                               batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiclass Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = gluon.nn.Sequential()\n",
    "with net.name_scope():\n",
    "    net.add(gluon.nn.Dense(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter initialization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax Cross Entropy Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer\n",
    "\n",
    "``gluon`` can keep the coefficients from diverging by applying a *weight decay* penalty. To get the real overfitting experience, we need to make sure it is switched off. We do this by passing `'wd': 0.0` when we instantiate the trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 0.0})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net):\n",
    "    # Fraction of examples whose highest-scoring class matches the label\n",
    "    numerator = 0.\n",
    "    denominator = 0.\n",
    "    for data, label in data_iterator:\n",
    "        data = data.as_in_context(ctx).reshape((-1,784))\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        predictions = nd.argmax(output, axis=1)\n",
    "        numerator += nd.sum(predictions == label)\n",
    "        denominator += data.shape[0]\n",
    "    return (numerator / denominator).asscalar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completed epoch 0. Loss: 2.41042828249, Train_acc 0.183, Test_acc 0.163\n",
      "Completed epoch 20. Loss: 0.980734459237, Train_acc 0.823, Test_acc 0.734\n",
      "Completed epoch 40. Loss: 0.656826479353, Train_acc 0.863, Test_acc 0.762\n",
      "Completed epoch 60. Loss: 0.547160148674, Train_acc 0.879, Test_acc 0.776\n",
      "Completed epoch 80. Loss: 0.410905732785, Train_acc 0.894, Test_acc 0.8\n",
      "Completed epoch 100. Loss: 0.5042345608, Train_acc 0.902, Test_acc 0.813\n",
      "Completed epoch 120. Loss: 0.393295650603, Train_acc 0.908, Test_acc 0.817\n",
      "Completed epoch 140. Loss: 0.31935924206, Train_acc 0.915, Test_acc 0.819\n",
      "Completed epoch 160. Loss: 0.353683424529, Train_acc 0.922, Test_acc 0.826\n",
      "Completed epoch 180. Loss: 0.441564063282, Train_acc 0.925, Test_acc 0.826\n",
      "Completed epoch 200. Loss: 0.439105920901, Train_acc 0.928, Test_acc 0.827\n",
      "Completed epoch 220. Loss: 0.335929108585, Train_acc 0.93, Test_acc 0.831\n",
      "Completed epoch 240. Loss: 0.341114005655, Train_acc 0.933, Test_acc 0.833\n",
      "Completed epoch 260. Loss: 0.265167222471, Train_acc 0.935, Test_acc 0.835\n",
      "Completed epoch 280. Loss: 0.225616094168, Train_acc 0.937, Test_acc 0.836\n",
      "Completed epoch 300. Loss: 0.31287811555, Train_acc 0.939, Test_acc 0.836\n",
      "Completed epoch 320. Loss: 0.201846450408, Train_acc 0.941, Test_acc 0.836\n",
      "Completed epoch 340. Loss: 0.26857858658, Train_acc 0.946, Test_acc 0.838\n",
      "Completed epoch 360. Loss: 0.250789181783, Train_acc 0.949, Test_acc 0.839\n",
      "Completed epoch 380. Loss: 0.292163310882, Train_acc 0.949, Test_acc 0.839\n",
      "Completed epoch 400. Loss: 0.364868967968, Train_acc 0.95, Test_acc 0.842\n",
      "Completed epoch 420. Loss: 0.16974889173, Train_acc 0.952, Test_acc 0.843\n",
      "Completed epoch 440. Loss: 0.20597298219, Train_acc 0.954, Test_acc 0.845\n",
      "Completed epoch 460. Loss: 0.135431291838, Train_acc 0.955, Test_acc 0.848\n",
      "Completed epoch 480. Loss: 0.192532801942, Train_acc 0.96, Test_acc 0.849\n",
      "Completed epoch 500. Loss: 0.139670844177, Train_acc 0.961, Test_acc 0.848\n",
      "Completed epoch 520. Loss: 0.188918106269, Train_acc 0.962, Test_acc 0.85\n",
      "Completed epoch 540. Loss: 0.182555058185, Train_acc 0.964, Test_acc 0.848\n",
      "Completed epoch 560. Loss: 0.149213267695, Train_acc 0.965, Test_acc 0.847\n",
      "Completed epoch 580. Loss: 0.215244527124, Train_acc 0.965, Test_acc 0.844\n",
      "Completed epoch 600. Loss: 0.220504981679, Train_acc 0.966, Test_acc 0.842\n",
      "Completed epoch 620. Loss: 0.225874061765, Train_acc 0.966, Test_acc 0.841\n",
      "Completed epoch 640. Loss: 0.105731161028, Train_acc 0.967, Test_acc 0.841\n",
      "Completed epoch 660. Loss: 0.229186017501, Train_acc 0.968, Test_acc 0.841\n",
      "Completed epoch 680. Loss: 0.195402703398, Train_acc 0.969, Test_acc 0.842\n"
     ]
    }
   ],
   "source": [
    "epochs = 700\n",
    "moving_loss = 0.\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx).reshape((-1,784))\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            cross_entropy = loss(output, label)\n",
    "        cross_entropy.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "        \n",
    "        ##########################\n",
    "        #  Keep a moving average of the losses\n",
    "        ##########################\n",
    "        if i == 0:\n",
    "            moving_loss = nd.mean(cross_entropy).asscalar()\n",
    "        else:\n",
    "            moving_loss = .99 * moving_loss + .01 * nd.mean(cross_entropy).asscalar()\n",
    "            \n",
    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
    "    if e % 20 == 0:\n",
    "        print(\"Completed epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" % \n",
    "              (e, moving_loss, train_accuracy, test_accuracy))           "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularization\n",
    "\n",
    "Now let's see what this mysterious *weight decay* is all about. We begin with a bit of math. When we add an L2 penalty to the weights, we are effectively adding $\\frac{\\lambda}{2} \\|w\\|^2$ to the loss. Hence, every time we compute the gradient $g_t$ of the original loss, the penalty contributes an additional term $\\lambda w$, which is exactly the derivative of the L2 penalty. As a result, with learning rate $\\eta$ we take a descent step not in the direction $-\\eta g_t$ but rather in the direction $-\\eta (g_t + \\lambda w)$. This shrinks $w$ by $\\eta \\lambda w$ at each step, hence the name weight decay. To make this work in practice, we just need to set the weight decay to something nonzero."
   ]
  },
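  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shrinkage is easiest to see by rearranging the update: $w_{t+1} = w_t - \\eta (g_t + \\lambda w_t) = (1 - \\eta \\lambda) w_t - \\eta g_t$. With a learning rate of $\\eta = 0.01$ and a weight decay of, say, $\\lambda = 0.001$, every step multiplies the weights by $1 - 10^{-5}$ before the usual gradient step is applied. The cell below is a minimal NumPy sketch of this identity, not ``gluon``'s actual optimizer code: the helper names `sgd_step_with_l2` and `sgd_step_with_decay` and the random vectors are made up for illustration, and we assume the plain SGD update described above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helpers for illustration; this is not gluon's optimizer code.\n",
    "def sgd_step_with_l2(w, grad, lr, wd):\n",
    "    # The penalty (wd / 2) * ||w||^2 adds wd * w to the gradient\n",
    "    return w - lr * (grad + wd * w)\n",
    "\n",
    "def sgd_step_with_decay(w, grad, lr, wd):\n",
    "    # Same step, written as: shrink the weights, then apply the plain gradient step\n",
    "    return (1.0 - lr * wd) * w - lr * grad\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "w = rng.randn(5)     # stand-in weights (made-up values)\n",
    "grad = rng.randn(5)  # stand-in gradient of the unpenalized loss\n",
    "lr, wd = 0.01, 0.001\n",
    "\n",
    "same = np.allclose(sgd_step_with_l2(w, grad, lr, wd),\n",
    "                   sgd_step_with_decay(w, grad, lr, wd))\n",
    "print('Both forms agree:', same)\n",
    "print('Per-step shrink factor:', 1.0 - lr * wd)"
   ]
  },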
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completed epoch 0. Loss: 2.30056658499, Train_acc 0.255, Test_acc 0.207\n",
      "Completed epoch 20. Loss: 0.953274930743, Train_acc 0.832, Test_acc 0.733\n",
      "Completed epoch 40. Loss: 0.773029556608, Train_acc 0.861, Test_acc 0.763\n",
      "Completed epoch 60. Loss: 0.643435905558, Train_acc 0.886, Test_acc 0.782\n",
      "Completed epoch 80. Loss: 0.573625952417, Train_acc 0.892, Test_acc 0.794\n",
      "Completed epoch 100. Loss: 0.471484044873, Train_acc 0.9, Test_acc 0.804\n",
      "Completed epoch 120. Loss: 0.396960894484, Train_acc 0.909, Test_acc 0.814\n",
      "Completed epoch 140. Loss: 0.409119025299, Train_acc 0.916, Test_acc 0.821\n",
      "Completed epoch 160. Loss: 0.34956342471, Train_acc 0.918, Test_acc 0.826\n",
      "Completed epoch 180. Loss: 0.364122123705, Train_acc 0.923, Test_acc 0.83\n",
      "Completed epoch 200. Loss: 0.23629064181, Train_acc 0.927, Test_acc 0.83\n",
      "Completed epoch 220. Loss: 0.35734446598, Train_acc 0.929, Test_acc 0.832\n",
      "Completed epoch 240. Loss: 0.272298600617, Train_acc 0.931, Test_acc 0.833\n",
      "Completed epoch 260. Loss: 0.362590129895, Train_acc 0.936, Test_acc 0.832\n",
      "Completed epoch 280. Loss: 0.27461734356, Train_acc 0.938, Test_acc 0.835\n",
      "Completed epoch 300. Loss: 0.253035167053, Train_acc 0.941, Test_acc 0.835\n",
      "Completed epoch 320. Loss: 0.257027310991, Train_acc 0.942, Test_acc 0.836\n",
      "Completed epoch 340. Loss: 0.359865081517, Train_acc 0.944, Test_acc 0.836\n",
      "Completed epoch 360. Loss: 0.275248116974, Train_acc 0.945, Test_acc 0.839\n",
      "Completed epoch 380. Loss: 0.357199658944, Train_acc 0.953, Test_acc 0.841\n",
      "Completed epoch 400. Loss: 0.175214320661, Train_acc 0.955, Test_acc 0.842\n",
      "Completed epoch 420. Loss: 0.270390815553, Train_acc 0.956, Test_acc 0.841\n",
      "Completed epoch 440. Loss: 0.199450716782, Train_acc 0.956, Test_acc 0.842\n",
      "Completed epoch 460. Loss: 0.151572242711, Train_acc 0.957, Test_acc 0.842\n",
      "Completed epoch 480. Loss: 0.225918149547, Train_acc 0.96, Test_acc 0.842\n",
      "Completed epoch 500. Loss: 0.309496395041, Train_acc 0.96, Test_acc 0.841\n",
      "Completed epoch 520. Loss: 0.194795705577, Train_acc 0.961, Test_acc 0.841\n",
      "Completed epoch 540. Loss: 0.211884802227, Train_acc 0.962, Test_acc 0.844\n",
      "Completed epoch 560. Loss: 0.163623161352, Train_acc 0.962, Test_acc 0.845\n",
      "Completed epoch 580. Loss: 0.1975672145, Train_acc 0.964, Test_acc 0.845\n",
      "Completed epoch 600. Loss: 0.136671575637, Train_acc 0.966, Test_acc 0.845\n",
      "Completed epoch 620. Loss: 0.267496398274, Train_acc 0.966, Test_acc 0.845\n",
      "Completed epoch 640. Loss: 0.185523537505, Train_acc 0.966, Test_acc 0.846\n",
      "Completed epoch 660. Loss: 0.154396287643, Train_acc 0.966, Test_acc 0.848\n",
      "Completed epoch 680. Loss: 0.183919263789, Train_acc 0.966, Test_acc 0.848\n"
     ]
    }
   ],
   "source": [
    "net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 0.001})\n",
    "\n",
    "moving_loss = 0.\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx).reshape((-1,784))\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            cross_entropy = loss(output, label)\n",
    "        cross_entropy.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "        \n",
    "        ##########################\n",
    "        #  Keep a moving average of the losses\n",
    "        ##########################\n",
    "        if i == 0:\n",
    "            moving_loss = nd.mean(cross_entropy).asscalar()\n",
    "        else:\n",
    "            moving_loss = .99 * moving_loss + .01 * nd.mean(cross_entropy).asscalar()\n",
    "            \n",
    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
    "    if e % 20 == 0:\n",
    "        print(\"Completed epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" % \n",
    "              (e, moving_loss, train_accuracy, test_accuracy))           "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the test accuracy improves a bit: at epoch 680 it is 0.848 with weight decay, compared with 0.842 without. The size of the improvement depends on the amount of weight decay, so we recommend that you experiment with different values. For instance, a larger weight decay (e.g. $0.01$) will lead to inferior performance, and one that's larger still ($0.1$) will lead to terrible results. This is one of the reasons why tuning hyperparameters is so important for getting good experimental results in practice."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Loss functions](../chapter02_supervised-learning/loss.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
